{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flash Evaluation on the DARPA OpTC Dataset\n",
    "\n",
    "This notebook is designed for evaluating Flash on the DARPA OpTC dataset. The OpTC dataset is a node-level dataset, crucial for our analysis. Flash is configured to operate in a node-level setting to effectively assess this dataset. The OpTC dataset is enriched with node attributes, making it suitable for running Flash in a decoupled manner. This includes using offline GNN embeddings and a downstream classifier. Our approach tests Flash on this dataset, where Flash generates word2vec embeddings as feature vectors for GNN embeddings. These embeddings are stored in a datastore and used in conjunction with a downstream model for improved detection results.\n",
    "\n",
    "## Accessing the Dataset:\n",
    "- The OpTC dataset can be accessed via this link: [OpTC Dataset](https://drive.google.com/drive/u/0/folders/1n3kkS3KR31KUegn42yk3-e6JkZvf0Caa).\n",
    "- Dataset files for evaluation will be downloaded automatically by the script.\n",
    "- While we provide pre-trained weights, you also have the option to download benign data files for training the models from the ground up.\n",
    "\n",
    "## Data Parsing and Execution:\n",
    "- The script is adept at autonomously parsing the downloaded data files.\n",
    "- For evaluation results, execute all cells in this notebook.\n",
    "\n",
    "## Model Training and Execution Options:\n",
    "- By default, the notebook utilizes pre-trained model weights.\n",
    "- It also offer settings to independently train Graph Neural Networks (GNNs), word2vec, and Xgboost models.\n",
    "- These independently trained models can then be deployed for an evaluation of the system.\n",
    "\n",
    "Following these guidelines will ensure a thorough and effective analysis of the OpTC dataset using Flash."
   ]
  },
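  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decoupled setup described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the notebook's implementation: `toy_store`, `jaccard`, `build_feature`, and the 3-dimensional vectors are hypothetical stand-ins (the notebook uses 20-dimensional vectors). A node's word2vec feature is concatenated with its stored GNN embedding only when the node's current neighbor set is similar enough to the stored one; otherwise zeros are appended.\n",
    "\n",
    "```python\n",
    "def jaccard(a, b):\n",
    "    # Jaccard similarity of two sets; 0.0 when both are empty.\n",
    "    union = a | b\n",
    "    return len(a & b) / len(union) if union else 0.0\n",
    "\n",
    "def build_feature(name, w2v_vec, neighbors, store, threshold=1.0, dim=3):\n",
    "    # Concatenate the word2vec vector with a stored GNN embedding, or with zeros.\n",
    "    gnn_vec = [0.0] * dim\n",
    "    if name in store:\n",
    "        emb, stored_neighbors = store[name]\n",
    "        if jaccard(neighbors, set(stored_neighbors)) >= threshold:\n",
    "            gnn_vec = list(emb)\n",
    "    return list(w2v_vec) + gnn_vec\n",
    "\n",
    "# Hypothetical datastore: node name -> (GNN embedding, neighbor names)\n",
    "toy_store = {'cmd.exe': ([0.1, 0.2, 0.3], ['explorer.exe', 'a.txt'])}\n",
    "\n",
    "same = build_feature('cmd.exe', [1.0, 1.0, 1.0], {'explorer.exe', 'a.txt'}, toy_store)\n",
    "drift = build_feature('cmd.exe', [1.0, 1.0, 1.0], {'explorer.exe', 'evil.dll'}, toy_store)\n",
    "print(same)   # stored embedding appended\n",
    "print(drift)  # zeros appended because the neighborhood changed\n",
    "```\n",
    "\n",
    "With a threshold of 1.0, any change in the neighbor set falls back to the zero vector."
   ]
  },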
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import os\n",
    "import torch.nn.functional as F\n",
    "import pickle\n",
    "import json\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "warnings.filterwarnings('ignore')\n",
    "from torch_geometric.loader import NeighborLoader\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "gnn_weights = \"trained_weights/optc/gnn_temp.pth\"\n",
    "xgboost_weights = \"trained_weights/optc/xgb.pkl\"\n",
    "word2vec_weights = 'w2v_optc.model'\n",
    "create_store = False\n",
    "gnnTrain = False\n",
    "xgbTrain = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "import gzip\n",
    "from sklearn.manifold import TSNE\n",
    "import json\n",
    "import copy\n",
    "import os\n",
    "import xgboost as xgb\n",
    "\n",
    "import gensim\n",
    "from gensim.models import Word2Vec\n",
    "from multiprocessing import Pool\n",
    "from itertools import compress\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "import multiprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import gensim\n",
    "from gensim.models.doc2vec import Doc2Vec, TaggedDocument\n",
    "from collections import Counter\n",
    "from gensim.models import Word2Vec\n",
    "from multiprocessing import Pool\n",
    "from itertools import compress\n",
    "from tqdm import tqdm\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import gzip\n",
    "import io\n",
    "\n",
    "def extract_logs(filepath, hostid):\n",
    "    search_pattern = f'SysClient{hostid}'\n",
    "    output_filename = f'SysClient{hostid}.systemia.com.txt'\n",
    "    \n",
    "    with gzip.open(filepath, 'rt', encoding='utf-8') as fin:\n",
    "        with open(output_filename, 'ab') as f:\n",
    "            out = io.BufferedWriter(f)\n",
    "            for line in fin:\n",
    "                if search_pattern in line:\n",
    "                    out.write(line.encode('utf-8'))\n",
    "            out.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gdown\n",
    "from tqdm import tqdm\n",
    "    \n",
    "def prepare_test_set():\n",
    "    urls = [\n",
    "        \"https://drive.google.com/file/d/1HFSyvmgH0jvdnnnTdKfWRjZYOrLWoIkv/view?usp=drive_link\",\n",
    "        \"https://drive.google.com/file/d/1pJLxJsDV8sngiedbfVajMetczIgM3PQd/view?usp=drive_link\",\n",
    "        \"https://drive.google.com/file/d/1fRQqc68r8-z5BL7H_eAKIDOeHp7okDuM/view?usp=drive_link\",\n",
    "        \"https://drive.google.com/file/d/1VfyGr8wfSe8LBIHBWuYBlU8c2CyEgO5C/view?usp=drive_link\",\n",
    "        \"https://drive.google.com/file/d/10N9ZPolq_L8HivBqzf_jFKbwjSxddsZp/view?usp=drive_link\",\n",
    "        \"https://drive.google.com/file/d/1xIr8gw-4zc8ESjUpYtrFsbOwhPGUSd15/view?usp=drive_link\",\n",
    "        \"https://drive.google.com/file/d/1PvlCp2oQaxEBEFGSQWfcFVj19zLOe7yH/view?usp=drive_link\"\n",
    "    ]\n",
    "\n",
    "    for url in urls:\n",
    "        gdown.download(url, quiet=False, use_cookies=False, fuzzy=True)\n",
    "\n",
    "    log_files = [\n",
    "        (\"AIA-201-225.ecar-2019-12-08T11-05-10.046.json.gz\", \"0201\"),\n",
    "        (\"AIA-201-225.ecar-last.json.gz\", \"0201\"),\n",
    "        (\"AIA-501-525.ecar-2019-11-17T04-01-58.625.json.gz\", \"0501\"),\n",
    "        (\"AIA-501-525.ecar-last.json.gz\", \"0501\"),\n",
    "        (\"AIA-51-75.ecar-last.json.gz\", \"0051\")\n",
    "    ]\n",
    "    \n",
    "    os.system(\"rm SysClient0201.com.txt\")\n",
    "    os.system(\"rm SysClient0501.com.txt\")\n",
    "    os.system(\"rm SysClient0051.com.txt\")\n",
    "\n",
    "    for file, code in tqdm(log_files, desc=\"Extracting logs\", unit=\"file\"):\n",
    "        extract_logs(file, code)\n",
    "\n",
    "prepare_test_set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def is_valid_entry(entry):\n",
    "    valid_objects = {'PROCESS', 'FILE', 'FLOW', 'MODULE'}\n",
    "    invalid_actions = {'START', 'TERMINATE'}\n",
    "\n",
    "    object_valid = entry['object'] in valid_objects\n",
    "    action_valid = entry['action'] not in invalid_actions\n",
    "    actor_object_different = entry['actorID'] != entry['objectID']\n",
    "\n",
    "    return object_valid and action_valid and actor_object_different\n",
    "\n",
    "def Traversal_Rules(data):\n",
    "    filtered_data = {}\n",
    "\n",
    "    for entry in data:\n",
    "        if is_valid_entry(entry):\n",
    "            key = (\n",
    "                entry['action'], \n",
    "                entry['actorID'], \n",
    "                entry['objectID'], \n",
    "                entry['object'], \n",
    "                entry['pid'], \n",
    "                entry['ppid']\n",
    "            )\n",
    "            filtered_data[key] = entry\n",
    "\n",
    "    return list(filtered_data.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def Sentence_Construction(entry):\n",
    "    action = entry[\"action\"]\n",
    "    properties = entry['properties']\n",
    "    object_type = entry['object']\n",
    "\n",
    "    format_strings = {\n",
    "        'PROCESS': \"{parent_image_path} {action} {image_path} {command_line}\",\n",
    "        'FILE': \"{image_path} {action} {file_path}\",\n",
    "        'FLOW': \"{image_path} {action} {src_ip} {src_port} {dest_ip} {dest_port} {direction}\",\n",
    "        'MODULE': \"{image_path} {action} {module_path}\"\n",
    "    }\n",
    "\n",
    "    default_format = \"{image_path} {action} {module_path}\"\n",
    "\n",
    "    try:\n",
    "        format_str = format_strings.get(object_type, default_format)\n",
    "        phrase = format_str.format(action=action, **properties)\n",
    "    except KeyError:\n",
    "        phrase = ''\n",
    "\n",
    "    return phrase.split(' ')"
   ]
  },
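  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the sentence construction above turns one event into a list of tokens. The following is a minimal self-contained sketch with a made-up FILE event: the field values and the `build_sentence` helper are hypothetical, and only the template mirrors the `FILE` entry of `format_strings`.\n",
    "\n",
    "```python\n",
    "# Template copied from the FILE entry of format_strings above.\n",
    "FILE_TEMPLATE = '{image_path} {action} {file_path}'\n",
    "\n",
    "def build_sentence(event):\n",
    "    # str.format ignores unused keyword arguments, so extra properties are harmless.\n",
    "    try:\n",
    "        phrase = FILE_TEMPLATE.format(action=event['action'], **event['properties'])\n",
    "    except KeyError:\n",
    "        return []  # a required property is missing\n",
    "    return phrase.split(' ')\n",
    "\n",
    "event = {\n",
    "    'action': 'READ',\n",
    "    'properties': {'image_path': 'cmd.exe', 'file_path': 'notes.txt', 'size': '10'},\n",
    "}\n",
    "print(build_sentence(event))\n",
    "```\n",
    "\n",
    "Unlike this sketch, `Sentence_Construction` returns `['']` when a property is missing, because `''.split(' ')` yields a one-element list."
   ]
  },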
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "\n",
    "def Extract_Semantic_Info(event):\n",
    "    object_type = event['object']\n",
    "    properties = event['properties']\n",
    "\n",
    "    label_mapping = {\n",
    "        \"PROCESS\": ('parent_image_path', 'image_path'),\n",
    "        \"FILE\": ('image_path', 'file_path'),\n",
    "        \"MODULE\": ('image_path', 'module_path'),\n",
    "        \"FLOW\": ('image_path', 'dest_ip', 'dest_port')\n",
    "    }\n",
    "\n",
    "    label_keys = label_mapping.get(object_type, None)\n",
    "    if label_keys:\n",
    "        labels = [properties.get(key) for key in label_keys]\n",
    "        if all(labels):\n",
    "            event[\"actorname\"], event[\"objectname\"] = labels[0], ' '.join(labels[1:])\n",
    "            return event\n",
    "    return None\n",
    "\n",
    "def transform(text):\n",
    "    labeled_data = [event for event in (Extract_Semantic_Info(x) for x in text) if event]\n",
    "    data = Traversal_Rules(labeled_data)\n",
    "\n",
    "    phrases = [Sentence_Construction(x) for x in data if Sentence_Construction(x)]\n",
    "    for datum, phrase in zip(data, phrases):\n",
    "        datum['phrase'] = phrase\n",
    "\n",
    "    df = pd.DataFrame(data)\n",
    "    df['timestamp'] = pd.to_datetime(df['timestamp'].str[:-6], infer_datetime_format=True)\n",
    "    df.sort_values(by='timestamp', inplace=True)\n",
    "\n",
    "    return df\n",
    "\n",
    "def load_data(file_path):\n",
    "    with open(file_path, 'r') as file:\n",
    "        content = [json.loads(line) for line in file]\n",
    "    \n",
    "    return Featurize(transform(content))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def Featurize(df):\n",
    "    dummies = {'PROCESS': 0, 'FLOW': 1, 'FILE': 2, 'MODULE': 3}\n",
    "\n",
    "    nodes = {}\n",
    "    labels = {}\n",
    "    lblmap = {}\n",
    "    neimap = {}\n",
    "    edges = []\n",
    "\n",
    "    for index, row in df.iterrows():\n",
    "        actor_id, object_id = row['actorID'], row[\"objectID\"]\n",
    "        object_type = row['object']\n",
    "\n",
    "        nodes.setdefault(actor_id, []).extend(row['phrase'])\n",
    "        nodes.setdefault(object_id, []).extend(row['phrase'])\n",
    "\n",
    "        labels[actor_id] = dummies.get('PROCESS', -1)\n",
    "        labels[object_id] = dummies.get(object_type, -1)\n",
    "\n",
    "        lblmap[actor_id] = row['actorname']\n",
    "        lblmap[object_id] = row['objectname']\n",
    "\n",
    "        neimap.setdefault(actor_id, set()).add(row['objectname'])\n",
    "        neimap.setdefault(object_id, set()).add(row['actorname'])\n",
    "\n",
    "        edge_type = row['properties']['direction'] if object_type == 'FLOW' else row['action']\n",
    "        edges.append((actor_id, object_id, edge_type))\n",
    "\n",
    "    features, feat_labels, edge_index = [], [], [[], []]\n",
    "    node_index = {}\n",
    "\n",
    "    for node, phrases in nodes.items():\n",
    "        if not (len(phrases) == 1 and phrases[0] == 'DELETE'):\n",
    "            features.append(infer(phrases))\n",
    "            feat_labels.append(labels[node])\n",
    "            node_index[node] = len(features) - 1\n",
    "\n",
    "    for src, dst, _ in edges:\n",
    "        edge_index[0].append(node_index[src])\n",
    "        edge_index[1].append(node_index[dst])\n",
    "\n",
    "    mapp = list(node_index.keys())\n",
    "\n",
    "    return features, np.array(feat_labels), edge_index, mapp, lblmap, neimap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from gensim.models.callbacks import CallbackAny2Vec\n",
    "\n",
    "class EpochSaver(CallbackAny2Vec):\n",
    "\n",
    "    def __init__(self):\n",
    "        self.epoch = 0\n",
    "\n",
    "    def on_epoch_end(self, model):\n",
    "        model.save('trained_weights/optc/w2v_optc.model')\n",
    "        self.epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EpochLogger(CallbackAny2Vec):\n",
    "\n",
    "    def __init__(self):\n",
    "        self.epoch = 0\n",
    "\n",
    "    def on_epoch_begin(self, model):\n",
    "        print(\"Epoch #{} start\".format(self.epoch))\n",
    "\n",
    "    def on_epoch_end(self, model):\n",
    "        print(\"Epoch #{} end\".format(self.epoch))\n",
    "        self.epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "def prepare_sentences(df):\n",
    "    nodes = {}\n",
    "    for index, row in df.iterrows():\n",
    "        for key in ['actorID', 'objectID']:\n",
    "            node_id = row[key]\n",
    "            nodes.setdefault(node_id, []).extend(row['phrase'])\n",
    "    return list(nodes.values())\n",
    "\n",
    "def train_word2vec_model(train_file_path):\n",
    "    with open(train_file_path, 'r') as file:\n",
    "        content = [json.loads(line) for line in file]\n",
    "\n",
    "    events = transform(content)\n",
    "    phrases = prepare_sentences(events)\n",
    "\n",
    "    logger = EpochLogger()\n",
    "    saver = EpochSaver()\n",
    "    word2vec = Word2Vec(sentences=phrases, vector_size=20, window=5, min_count=1, workers=8, epochs=300, callbacks=[saver, logger])\n",
    "\n",
    "    return word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import numpy as np\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "class PositionalEncoder:\n",
    "\n",
    "    def __init__(self, d_model, max_len=100000):\n",
    "        position = torch.arange(max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
    "        self.pe = torch.zeros(max_len, d_model)\n",
    "        self.pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        self.pe[:, 1::2] = torch.cos(position * div_term)\n",
    "\n",
    "    def embed(self, x):\n",
    "        return x + self.pe[:x.size(0)]\n",
    "\n",
    "\n",
    "def infer(document):\n",
    "    word_embeddings = [w2vmodel.wv[word] for word in document if word in  w2vmodel.wv]\n",
    "    \n",
    "    if not word_embeddings:\n",
    "        return np.zeros(20)\n",
    "\n",
    "    output_embedding = torch.tensor(word_embeddings, dtype=torch.float)\n",
    "    if len(document) < 100000:\n",
    "        output_embedding = encoder.embed(output_embedding)\n",
    "\n",
    "    output_embedding = output_embedding.detach().cpu().numpy()\n",
    "    return np.mean(output_embedding, axis=0)\n",
    "\n",
    "encoder = PositionalEncoder(20)\n",
    "w2vmodel = Word2Vec.load(word2vec_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import SAGEConv\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GCN, self).__init__()\n",
    "        self.conv1 = SAGEConv(20, 32, normalize=True)\n",
    "        self.conv2 = SAGEConv(32, 20, normalize=True)\n",
    "        self.linear = nn.Linear(in_features=20, out_features=4)\n",
    "\n",
    "    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:\n",
    "    \n",
    "        x = self.encode(x, edge_index)\n",
    "        x = self.linear(x)\n",
    "        return F.softmax(x, dim=1)\n",
    "    \n",
    "    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:\n",
    "        \n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, p=0.5, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "model = GCN().to(device)\n",
    "if not gnnTrain:\n",
    "    model.load_state_dict(torch.load(gnn_weights))\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.utils import class_weight\n",
    "\n",
    "if gnnTrain or create_store:\n",
    "    file_path = 'Enter Path to Train File Here'\n",
    "    nodes,labels,edges,mapp,lbl,nemap = load_data(file_path)\n",
    "\n",
    "    l = np.array(labels)\n",
    "    class_weights = class_weight.compute_class_weight(class_weight = \"balanced\",classes = np.unique(l),y = l)\n",
    "    class_weights = torch.tensor(class_weights,dtype=torch.float).to(device)\n",
    "    criterion = CrossEntropyLoss(weight=class_weights,reduction='mean')\n",
    "\n",
    "    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch_geometric.loader import NeighborLoader\n",
    "\n",
    "def train_model(batch):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    predictions = model(batch.x, batch.edge_index)\n",
    "    loss = criterion(predictions, batch.y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item(), batch.x.size(0)\n",
    "\n",
    "def evaluate_model(batch):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        predictions = model(batch.x, batch.edge_index)\n",
    "        pred_labels = predictions.argmax(dim=1)\n",
    "        correct_predictions = int((pred_labels == batch.y).sum())\n",
    "    return correct_predictions\n",
    "\n",
    "if gnnTrain:\n",
    "    loader = NeighborLoader(graph, num_neighbors=[-1, -1], batch_size=5000)\n",
    "\n",
    "    for epoch in range(100):\n",
    "        total_loss = total_correct = total_nodes = 0\n",
    "\n",
    "        for batch in loader:\n",
    "            loss, nodes = train_model(batch)\n",
    "            total_loss += loss\n",
    "            total_nodes += nodes\n",
    "            total_correct += evaluate_model(batch)\n",
    "\n",
    "        average_loss = total_loss / total_nodes\n",
    "        accuracy = total_correct / total_nodes\n",
    "\n",
    "        print(f\"Epoch #{epoch}. Training Loss: {average_loss:.5f}, Accuracy: {accuracy:.5f}\")\n",
    "        torch.save(model.state_dict(), gnn_weights)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if create_store:\n",
    "    model.eval()\n",
    "    out = model.encode(graph.x, graph.edge_index).tolist()\n",
    "    \n",
    "    gnn_map = {}\n",
    "    \n",
    "    for i in range(len(mapp)):\n",
    "        gnn_map[lbl[mapp[i]]] = (out[i],list(nemap[mapp[i]]))\n",
    "    \n",
    "    with open(\"data_files/emb_store.json\", \"w\") as file:\n",
    "        json.dump(gnn_map, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "with open(\"data_files/emb_store.json\", \"r\") as file:\n",
    "    gnn_map = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def load_features(filename=None, similarity=1):\n",
    "    nodes, y_train, edges, mapp, lbl, nemap = load_data(filename)\n",
    "    zero_vector = np.zeros(20)\n",
    "\n",
    "    X_train = []\n",
    "    for idx, map_item in enumerate(mapp):\n",
    "        label = lbl[map_item]\n",
    "        node_feature = nodes[idx]\n",
    "\n",
    "        if label in gnn_map:\n",
    "            emb, stored_set = gnn_map[label]\n",
    "            current_set = nemap[map_item]\n",
    "            jaccard_similarity = len(current_set.intersection(stored_set)) / len(current_set.union(stored_set))\n",
    "\n",
    "            feature_vector = emb if jaccard_similarity >= similarity else zero_vector\n",
    "        else:\n",
    "            feature_vector = zero_vector\n",
    "\n",
    "        X_train.append(np.hstack((node_feature, feature_vector)))\n",
    "\n",
    "    return np.array(X_train), y_train, edges, mapp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from collections import Counter\n",
    "import xgboost as xgb\n",
    "\n",
    "if xgbTrain:\n",
    "    file_path = 'Enter Path to Train File Here'\n",
    "    x,y,_,_ = load_features(file_path)\n",
    "\n",
    "    xgb_cl = xgb.XGBClassifier()\n",
    "\n",
    "    xgb_cl.fit(x,y)\n",
    "    pickle.dump(xgb_cl, open(xgboost_weights, \"wb\"))\n",
    "\n",
    "    preds = xgb_cl.predict(x)\n",
    "    print(accuracy_score(y, preds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_pkl(fname):\n",
    "    with open(fname, 'rb') as f:\n",
    "        obj = pickle.load(f)\n",
    "    return obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def validate(file_path):\n",
    "    x,y,_,_ = load_features(file_path)\n",
    "    xgb_cl = load_pkl(xgboost_weights)\n",
    "\n",
    "    pred = xgb_cl.predict(x)\n",
    "    proba = xgb_cl.predict_proba(x)\n",
    "\n",
    "    sorted = np.sort(proba, axis=1)\n",
    "    conf = (sorted[:,-1] - sorted[:,-2]) / sorted[:,-1]\n",
    "    conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "    check = (pred == y)\n",
    "    flag = ~torch.tensor(check)\n",
    "    scores = conf[flag].tolist()\n",
    "    return scores"
   ]
  },
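  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confidence score computed in `validate` is the relative margin between the two most probable classes. A small worked sketch in plain Python follows; the probability rows are made up for illustration and `relative_margin` is a hypothetical helper, not part of the notebook.\n",
    "\n",
    "```python\n",
    "def relative_margin(proba_row):\n",
    "    # (top1 - top2) / top1: 0 when the top two classes tie, 1 when one class takes all the mass.\n",
    "    top2, top1 = sorted(proba_row)[-2:]\n",
    "    return (top1 - top2) / top1\n",
    "\n",
    "rows = [\n",
    "    [0.5, 0.5, 0.0, 0.0],  # ambiguous\n",
    "    [0.8, 0.2, 0.0, 0.0],  # fairly confident\n",
    "    [1.0, 0.0, 0.0, 0.0],  # certain\n",
    "]\n",
    "margins = [relative_margin(r) for r in rows]\n",
    "print(margins)\n",
    "```\n",
    "\n",
    "The notebook additionally rescales these margins with `(conf - conf.min()) / conf.max()` before using them."
   ]
  },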
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from itertools import compress\n",
    "from torch_geometric import utils\n",
    "\n",
    "def Get_Adjacent(ids, mapp, edges, hops):\n",
    "    if hops == 0:\n",
    "        return set()\n",
    "    \n",
    "    neighbors = set()\n",
    "    for edge in zip(edges[0], edges[1]):\n",
    "        if any(mapp[node] in ids for node in edge):\n",
    "            neighbors.update(mapp[node] for node in edge)\n",
    "\n",
    "    if hops > 1:\n",
    "        neighbors = neighbors.union(Get_Adjacent(neighbors, mapp, edges, hops - 1))\n",
    "    \n",
    "    return neighbors\n",
    "\n",
    "def calculate_metrics(TP, FP, FN, TN):\n",
    "    FPR = FP / (FP + TN) if FP + TN > 0 else 0\n",
    "    TPR = TP / (TP + FN) if TP + FN > 0 else 0\n",
    "\n",
    "    prec = TP / (TP + FP) if TP + FP > 0 else 0\n",
    "    rec = TP / (TP + FN) if TP + FN > 0 else 0\n",
    "    fscore = (2 * prec * rec) / (prec + rec) if prec + rec > 0 else 0\n",
    "\n",
    "    return prec, rec, fscore, FPR, TPR\n",
    "\n",
    "def helper(MP, all_pids, GP, edges, mapp):\n",
    "    TP = MP.intersection(GP)\n",
    "    FP = MP - GP\n",
    "    FN = GP - MP\n",
    "    TN = all_pids - (GP | MP)\n",
    "\n",
    "    two_hop_gp = Get_Adjacent(GP, mapp, edges, 2)\n",
    "    two_hop_tp = Get_Adjacent(TP, mapp, edges, 2)\n",
    "    FPL = FP - two_hop_gp\n",
    "    TPL = TP.union(FN.intersection(two_hop_tp))\n",
    "    FN = FN - two_hop_tp\n",
    "\n",
    "    TP, FP, FN, TN = len(TPL), len(FPL), len(FN), len(TN)\n",
    "\n",
    "    prec, rec, fscore, FPR, TPR = calculate_metrics(TP, FP, FN, TN)\n",
    "    print(f\"Precision: {round(prec, 2)}, Recall: {round(rec, 2)}, Fscore: {round(fscore, 2)}\")\n",
    "    \n",
    "    return TPL, FPL"
   ]
  },
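  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`helper` does not count a flagged node as a false positive when it lies within two hops of a ground-truth node. A minimal sketch of that adjustment on a toy graph is shown here; the node names and the `neighbors_within` helper are hypothetical, and the notebook itself uses `Get_Adjacent` on index-based edges.\n",
    "\n",
    "```python\n",
    "def neighbors_within(seeds, edges, hops):\n",
    "    # Repeatedly add both endpoints of every edge that touches the reached set.\n",
    "    reached = set(seeds)\n",
    "    for _ in range(hops):\n",
    "        reached |= {n for a, b in edges if a in reached or b in reached for n in (a, b)}\n",
    "    return reached\n",
    "\n",
    "edges = [('A', 'B'), ('B', 'C'), ('C', 'D'), ('X', 'Y')]\n",
    "ground_truth = {'A'}\n",
    "flagged = {'A', 'C', 'X'}\n",
    "\n",
    "near_attack = neighbors_within(ground_truth, edges, 2)\n",
    "false_pos = (flagged - ground_truth) - near_attack\n",
    "print(sorted(near_attack), sorted(false_pos))\n",
    "```\n",
    "\n",
    "Here `C` is flagged but sits two hops from the attack node `A`, so only `X` remains a false positive."
   ]
  },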
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def load_features_test(dataframe, similarity_threshold=1):\n",
    "    nodes, y_train, edges, mapping, label_map, node_entity_map = Featurize(dataframe)\n",
    "    X_train = []\n",
    "\n",
    "    for i, map_id in enumerate(mapping):\n",
    "        label = label_map[map_id]\n",
    "        node_embedding = np.zeros(20)  \n",
    "\n",
    "        if label in gnn_map:\n",
    "            embedding, stored_set = gnn_map[label]\n",
    "            current_set = node_entity_map[map_id]\n",
    "            similarity_metric = len(current_set.intersection(stored_set)) / len(current_set.union(stored_set))\n",
    "\n",
    "            if similarity_metric >= similarity_threshold:\n",
    "                node_embedding = np.array(embedding)\n",
    "\n",
    "        X_train.append(np.hstack((nodes[i], node_embedding)))\n",
    "\n",
    "    return np.array(X_train), y_train, edges, mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch_geometric import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_events_from_hosts(hosts):\n",
    "    all_events = []\n",
    "    for host in hosts:\n",
    "        path = f'SysClient0{host}.systemia.com.txt'\n",
    "        with open(path, 'r') as file:\n",
    "            raw_events = [json.loads(line) for line in file]\n",
    "        all_events.extend(raw_events)\n",
    "    return all_events\n",
    "\n",
    "def load_ground_truth(gt_file):\n",
    "    with open(gt_file, 'r') as file:\n",
    "        gt_nodes = set(file.read().split())\n",
    "    return gt_nodes\n",
    "\n",
    "def evaluate_model(df, xgb_cl, similarity_threshold, confidence_threshold):\n",
    "    x, y, edges, mapp = load_features_test(df)\n",
    "\n",
    "    pred = xgb_cl.predict(x)\n",
    "    proba = xgb_cl.predict_proba(x)\n",
    "\n",
    "    sorted_proba = np.sort(proba, axis=1)\n",
    "    conf = (sorted_proba[:, -1] - sorted_proba[:, -2]) / sorted_proba[:, -1]\n",
    "    normalized_conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "    check = (pred == y) & (normalized_conf > confidence_threshold)\n",
    "    flag = ~torch.tensor(check)\n",
    "\n",
    "    index = utils.mask_to_index(flag).tolist()\n",
    "    ids = {mapp[idx] for idx in index}\n",
    "    return ids,edges,mapp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def read_event_data(host):\n",
    "    file_path = f'SysClient0{host}.systemia.com.txt'\n",
    "    with open(file_path, 'r') as file:\n",
    "        return [json.loads(line) for line in file]\n",
    "        \n",
    "def stream_events(batch_size, window_size):\n",
    "    event_buffer = {}\n",
    "    hosts = ['051']\n",
    "    positions = {host: 0 for host in hosts}\n",
    "    while True:\n",
    "        for host in hosts:\n",
    "            if host not in event_buffer or len(event_buffer[host]) < positions[host] + batch_size:\n",
    "                events = read_event_data(host)\n",
    "                dframe = transform(events)\n",
    "                if host in event_buffer:\n",
    "                    event_buffer[host] = event_buffer[host].append(dframe, ignore_index=True)\n",
    "                else:\n",
    "                    event_buffer[host] = dframe\n",
    "            start = positions[host]\n",
    "            end = start + batch_size\n",
    "            yield event_buffer[host][start:end]\n",
    "            positions[host] += window_size\n",
    "            if positions[host] >= len(event_buffer[host]):\n",
    "                return\n",
    "\n",
    "def analyze_events(data_frame, ground_truth_nodes):\n",
    "    \n",
    "    if data_frame['properties'].apply(lambda x: isinstance(x, str)).any():\n",
    "        data_frame['properties'] = data_frame['properties'].apply(json.loads)\n",
    "        \n",
    "    actor_and_object_ids = set(data_frame['actorID']) | set(data_frame['objectID'])\n",
    "    relevant_ground_truth = {x for x in ground_truth_nodes if x in actor_and_object_ids}\n",
    "\n",
    "    features, labels, edges, mapping = load_features_test(data_frame)\n",
    "    model = load_pkl(xgboost_weights)\n",
    "\n",
    "    predictions = model.predict(features)\n",
    "    probabilities = model.predict_proba(features)\n",
    "\n",
    "    sorted_probabilities = np.sort(probabilities, axis=1)\n",
    "    confidence_scores = (sorted_probabilities[:, -1] - sorted_probabilities[:, -2]) / sorted_probabilities[:, -1]\n",
    "    normalized_confidence = (confidence_scores - confidence_scores.min()) / confidence_scores.max()\n",
    "\n",
    "    misclassified = ~torch.tensor(predictions == labels)\n",
    "    misclassified_indices = utils.mask_to_index(misclassified).tolist()\n",
    "    misclassified_ids = {mapping[idx] for idx in misclassified_indices}\n",
    "\n",
    "    helper(misclassified_ids, actor_and_object_ids, relevant_ground_truth, edges, mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def traverse(ids, mapping, edges, hops, visited=None):\n",
    "    if hops == 0:\n",
    "        return set()\n",
    "\n",
    "    if visited is None:\n",
    "        visited = set()\n",
    "\n",
    "    neighbors = set()\n",
    "    for src, dst in zip(edges[0], edges[1]):\n",
    "        src_mapped, dst_mapped = mapping[src], mapping[dst]\n",
    "\n",
    "        if (src_mapped in ids and dst_mapped not in visited) or \\\n",
    "           (dst_mapped in ids and src_mapped not in visited):\n",
    "            neighbors.add(src_mapped)\n",
    "            neighbors.add(dst_mapped)\n",
    "\n",
    "        visited.add(src_mapped)\n",
    "        visited.add(dst_mapped)\n",
    "\n",
    "    neighbors.difference_update(ids) \n",
    "    return ids.union(traverse(neighbors, mapping, edges, hops - 1, visited))\n",
    "\n",
    "def load_data(file_path):\n",
    "    with open(file_path, 'r') as file:\n",
    "        return json.load(file)\n",
    "\n",
    "def find_connected_alerts(start_alert, mapping, edges, depth, remaining_alerts):\n",
    "    connected_path = traverse({start_alert}, mapping, edges, depth)\n",
    "    return connected_path.intersection(remaining_alerts)\n",
    "\n",
    "def generate_incident_graphs(alerts, edges, mapping, depth):\n",
    "    incident_graphs = []\n",
    "    remaining_alerts = set(alerts)\n",
    "\n",
    "    while remaining_alerts:\n",
    "        alert = remaining_alerts.pop()\n",
    "        connected_alerts = find_connected_alerts(alert, mapping, edges, depth, remaining_alerts)\n",
    "\n",
    "        if len(connected_alerts) > 1:\n",
    "            incident_graphs.append(connected_alerts)\n",
    "            remaining_alerts -= connected_alerts\n",
    "\n",
    "    return incident_graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Testing Flash on OpTC Malicious Upgrade Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "all_events = load_events_from_hosts(['051'])\n",
    "\n",
    "EnActIds = [x['actorID'] for x in all_events]\n",
    "EnObjIds = [x['objectID'] for x in all_events]\n",
    "EntitySet = set(EnActIds).union(set(EnObjIds))\n",
    "\n",
    "df = transform(all_events)\n",
    "\n",
    "gt_nodes = load_ground_truth('optc.txt')\n",
    "gt_nodes = [x for x in gt_nodes if x in EntitySet]\n",
    "gt_nodes = set(gt_nodes)\n",
    "\n",
    "xgboost_model = load_pkl(xgboost_weights)\n",
    "identified_ids,edges,mapp = evaluate_model(df, xgboost_model, 1, 0.6)\n",
    "\n",
    "alerts = helper(identified_ids, EntitySet, gt_nodes, edges, mapp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing Flash on OpTC Plain PowerShell Empire Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_events = load_events_from_hosts(['201'])\n",
    "\n",
    "EnActIds = [x['actorID'] for x in all_events]\n",
    "EnObjIds = [x['objectID'] for x in all_events]\n",
    "EntitySet = set(EnActIds).union(set(EnObjIds))\n",
    "\n",
    "df = transform(all_events)\n",
    "\n",
    "gt_nodes = load_ground_truth('optc.txt')\n",
    "gt_nodes = [x for x in gt_nodes if x in EntitySet]\n",
    "gt_nodes = set(gt_nodes)\n",
    "\n",
    "xgboost_model = load_pkl(xgboost_weights)\n",
    "identified_ids,edges,mapp = evaluate_model(df, xgboost_model, 1, 0)\n",
    "\n",
    "alerts = helper(identified_ids, EntitySet, gt_nodes, edges, mapp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing Flash on OpTC Custom PowerShell Empire Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "all_events = load_events_from_hosts(['501'])\n",
    "\n",
    "EnActIds = [x['actorID'] for x in all_events]\n",
    "EnObjIds = [x['objectID'] for x in all_events]\n",
    "EntitySet = set(EnActIds).union(set(EnObjIds))\n",
    "\n",
    "df = transform(all_events)\n",
    "\n",
    "gt_nodes = load_ground_truth('optc.txt')\n",
    "gt_nodes = [x for x in gt_nodes if x in EntitySet]\n",
    "gt_nodes = set(gt_nodes)\n",
    "\n",
    "xgboost_model = load_pkl(xgboost_weights)\n",
    "identified_ids,edges,mapp = evaluate_model(df, xgboost_model, 1, 0.98)\n",
    "\n",
    "alerts = helper(identified_ids, EntitySet, gt_nodes, edges, mapp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing Flash on Streaming Batches Generated from OpTC Attack Logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stream = False\n",
    "if stream:\n",
    "    for data_frame in stream_events(250000, 250):\n",
    "        gt_nodes = load_ground_truth('optc.txt')\n",
    "        analyze_events(data_frame, gt_nodes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyTorch 1.12.0",
   "language": "python",
   "name": "pytorch-1.12.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
